(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
Pytree 是 JAX 定義的資料結構,依照 JAX 官方文件 [29.1] :
a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.
舉一些 Pytrees 的例子:
# 一層的 Pytree,內含三個葉節點。
[1, 'a', object()]
# 二層的 Pytree。
(1, (2, 3), ())
# 三層的 Pytree。
[1, {'k1': 2, 'k2': (3, 4)}, 5]
值得注意的是,我們常用的 DeviceArray 資料結構,它是一個葉節點,而非容器。
另外,None 在 jax Pytree 裏被視為「空容器」,而非葉節點。
在 JAX 裏,Pytree 常常用來包裝 (1) 模型參數 model parameters ,(2) 資料集 dataset entries 和 (3) RL agent observations,使得它們便於管理及儲存。
在這裏老頭介紹幾個 Pytree 常用的 API,完整的 API 函式列表,可以參考 [29.2]。
tree_leaves () 會將 tree 中的葉節點一一取出,並將它們放置在 list 中回傳。可以把它視為壓平整個 tree 。
[按:在早期的 JAX 版本中,tree_leaves 是直接放在 jax 封裝下(jax.tree_leaves),所以在網路上比較舊的範例程式中,讀者們還可以看到這樣的呼叫方式。]
import jax.tree_util as jtree
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Let's see how many leaves they have:
for pytree in example_trees:
leaves = jtree.tree_leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
output:
前面提到,None 在 jax Pytree 裏被視為「空容器」,而非葉節點。可以用下列的程式片斷來檢測這個說法:
jtree.tree_leaves([None, None, None])
output:
[]
is_leaf= 的用法,是在變更 Pytree 判斷樹中某一節點是不是葉節點的方式。假設我們希望暫時地將 list 或 dict 視為葉結節點時,我們就可以分別定義「判別函式」,去判斷輸入的資料型態是不是 list / dict ,如果是,就回傳 True,告訴 jax.tree_util 它是葉結點。
在呼叫 tree_leaves() 時,利用 is_leaf= 指定判別函式。
# to force a container type as leaf.
# ==============================================================================
# force list to be a leaf
def check_list(x):
return isinstance(x, list)
# force dict to be a leaf
def check_dict(x):
return isinstance(x, dict)
print(jtree.tree_leaves(example_trees[2]))
print(jtree.tree_leaves(example_trees[2], is_leaf=check_list))
print(jtree.tree_leaves(example_trees[2], is_leaf=check_dict))
output:
[1, 2, 3, 4, 5]
[[1, {'k1': 2, 'k2': (3, 4)}, 5]]
[1, {'k1': 2, 'k2': (3, 4)}, 5]
tree_map 的運作方式,基本上和 Python map 一樣,套用函式 f 在所有的葉節點上,回傳一個結構相同,但包含新的值的 Pytree。
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree_map(lambda x: x*2, list_of_lists)
output:
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
函式 f 也可以被套用於多個 Pytree 的葉節點上,例如:
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x-y, list_of_lists, another_list_of_lists)
output:
[[0, 0, 0], [0, 0], [0, 0, 0, 0]]
使用時要注意,函式 f 輸入參數的數量,必須和呼叫 tree_map() 時輸入的 pytree 數量一致,而且這些輸入的 Pytree 必須要有一致的結構,否則會產生執行時錯誤。
is_leaf= 的用法,可以參考前面 tree_leaves() API。
回傳 Pytree 的結構定義,容器以 Python 的語法來顯示,葉節點則以 * 表示。
for pytree in example_trees:
struct = jtree.tree_structure(pytree)
print(f"{repr(pytree):<45} : {struct}")
output:
目前為止,對於 Pytree 的介紹可能還是太抽象了,稍後老頭會打造一個簡單的神經網路模型, 以實際的例子來說明 Pytree 的用法。
註:
[29.1] 參考 What is a pytree 。
[29.2] 可參考 jax.tree_util package。